{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyskl.smp import *\n",
    "from pyskl.models import build_model\n",
    "import torch\n",
    "from mmcv import load, dump\n",
    "import copy as cp\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cfg = dict(\n",
    "    type='MMRecognizer3D',\n",
    "    backbone=dict(\n",
    "        type='RGBPoseConv3D',\n",
    "        speed_ratio=4, \n",
    "        channel_ratio=4, \n",
    "        pose_pathway=dict(\n",
    "            num_stages=3, \n",
    "            stage_blocks=(4, 6, 3),\n",
    "            lateral=True,\n",
    "            lateral_inv=True,\n",
    "            lateral_infl=16,\n",
    "            lateral_activate=(0, 1, 1),\n",
    "            in_channels=17,\n",
    "            base_channels=32,\n",
    "            out_indices=(2, ),\n",
    "            conv1_kernel=(1, 7, 7),\n",
    "            conv1_stride=(1, 1),\n",
    "            pool1_stride=(1, 1),\n",
    "            inflate=(0, 1, 1),\n",
    "            spatial_strides=(2, 2, 2),\n",
    "            temporal_strides=(1, 1, 1))),\n",
    "    cls_head=dict(\n",
    "        type='RGBPoseHead',\n",
    "        in_channels=(2048, 512),\n",
    "        num_classes=60,\n",
    "        dropout=0.5),\n",
    "    test_cfg = dict(average_clips='prob'))\n",
    "model = build_model(model_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we use the provided RGB / Pose pretrained weights as an example\n",
    "download_file('https://download.openmmlab.com/mmaction/pyskl/ckpt/rgbpose_conv3d/pose_only.pth')\n",
    "download_file('https://download.openmmlab.com/mmaction/pyskl/ckpt/rgbpose_conv3d/rgb_only.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "rgb_ckpt = torch.load('rgb_only.pth', map_location='cpu')['state_dict']\n",
    "pose_ckpt = torch.load('pose_only.pth', map_location='cpu')['state_dict']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "rgb_ckpt = {k.replace('backbone', 'backbone.rgb_path').replace('fc_cls', 'fc_rgb'): v for k, v in rgb_ckpt.items()}\n",
    "pose_ckpt = {k.replace('backbone', 'backbone.pose_path').replace('fc_cls', 'fc_pose'): v for k, v in pose_ckpt.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_ckpt = {}\n",
    "old_ckpt.update(rgb_ckpt)\n",
    "old_ckpt.update(pose_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The difference is in dim-1\n",
    "def padding(weight, new_shape):\n",
    "    new_weight = weight.new_zeros(new_shape)\n",
    "    new_weight[:, :weight.shape[1]] = weight\n",
    "    return new_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = cp.deepcopy(old_ckpt)\n",
    "name = 'backbone.rgb_path.layer3.0.conv1.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (256, 640, 3, 1, 1))\n",
    "name = 'backbone.rgb_path.layer3.0.downsample.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (1024, 640, 1, 1, 1))\n",
    "name = 'backbone.rgb_path.layer4.0.conv1.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (512, 1280, 3, 1, 1))\n",
    "name = 'backbone.rgb_path.layer4.0.downsample.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (2048, 1280, 1, 1, 1))\n",
    "name = 'backbone.pose_path.layer2.0.conv1.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (64, 160, 3, 1, 1))\n",
    "name = 'backbone.pose_path.layer2.0.downsample.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (256, 160, 1, 1, 1))\n",
    "name = 'backbone.pose_path.layer3.0.conv1.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (128, 320, 3, 1, 1))\n",
    "name = 'backbone.pose_path.layer3.0.downsample.conv.weight'\n",
    "ckpt[name] = padding(ckpt[name], (512, 320, 1, 1, 1))\n",
    "ckpt = OrderedDict(ckpt)\n",
    "torch.save({'state_dict': ckpt}, 'rgbpose_conv3d_init.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=['backbone.rgb_path.layer2_lateral.conv.weight', 'backbone.rgb_path.layer3_lateral.conv.weight', 'backbone.pose_path.layer1_lateral.conv.weight', 'backbone.pose_path.layer1_lateral.bn.weight', 'backbone.pose_path.layer1_lateral.bn.bias', 'backbone.pose_path.layer1_lateral.bn.running_mean', 'backbone.pose_path.layer1_lateral.bn.running_var', 'backbone.pose_path.layer2_lateral.conv.weight', 'backbone.pose_path.layer2_lateral.bn.weight', 'backbone.pose_path.layer2_lateral.bn.bias', 'backbone.pose_path.layer2_lateral.bn.running_mean', 'backbone.pose_path.layer2_lateral.bn.running_var'], unexpected_keys=[])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(ckpt, strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5 (default, Sep  3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e9762a6b93aa21d4aa12033e45e3b754f20c88c84120cffc73bc7fccd60bfa55"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
